package com.kingsoft.dc.khaos.module.spark.util

import java.text.SimpleDateFormat
import java.util

import com.kingsoft.dc.khaos.KhaosContext
import com.kingsoft.dc.khaos.extender.meta.api.{DmTableSplit, ReportDataStatusResult}
import com.kingsoft.dc.khaos.extender.meta.model.col.DmTableColumn
import com.kingsoft.dc.khaos.module.spark.constants.{ColumnType, MppSqlTypes, SchedulerConstants}
import com.kingsoft.dc.khaos.module.spark.metadata.sink.ExtractFieldInfo
import com.kingsoft.dc.khaos.module.spark.model.RelationDataStatusInfo
import com.kingsoft.dc.khaos.module.spark.model.cos.CosDataStatusInfo
import com.kingsoft.dc.khaos.module.spark.model.hdfs.HdfsDataStatusInfo
import com.kingsoft.dc.khaos.module.spark.model.ks3.Ks3DataStatusInfo
import com.kingsoft.dc.khaos.module.spark.request.model.{JdbcConnectEntity, StructFieldEntity}
import com.kingsoft.dc.khaos.module.spark.util.DataStatusUtils.IndicatorsEnum
import com.kingsoft.dc.khaos.util.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.functions.{col, lit, trim}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Column, DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK
import org.apache.spark.util.LongAccumulator


import scala.collection.JavaConverters._
import scala.collection.{immutable, mutable}
import scala.collection.mutable.{ArrayBuffer, ListBuffer}
import scala.util.Random

/**
 * Created by jing on 19/4/20.
 */
object DataframeUtils extends Logging {

  // PostgreSQL/MPP DDL type names used when generating external-table column definitions.
  val VARCHAR = "varchar"
  val BIGINT = "bigint"
  val INTEGER = "integer"
  val DOUBLE = "double precision"
  val TIMESTAMP = "timestamp"
  val DATE = "date"
  val TIME = "time"
  val NUMERIC = "numeric"

  // Regex for a literal dot (escaped) — presumably used with String.split at call sites; TODO confirm.
  val CSV_DELIMITER: String = "\\."

  /**
   * Rebuilds a DataFrame for a PostgreSQL sink: casts every source column to the Spark
   * type that matches the target table's column type, then keeps only the columns
   * required by the front-end field mapping (auto-increment serial columns are dropped
   * when they have no upstream link).
   *
   * @param tblName       target table name (not referenced in the body; kept for interface parity)
   * @param mppFileds     front-end field mapping (sink fields and their upstream links)
   * @param data          source DataFrame
   * @param connectEntity JDBC connection wrapper used to read the target table schema
   * @return (rebuilt DataFrame, target table schema keyed by column name)
   */
  def buildNewDataframePGSql(tblName: String,
                             mppFileds: List[ExtractFieldInfo],
                             data: DataFrame,
                             connectEntity: JdbcConnectEntity) = {
    var targetData: DataFrame = data
    var tableSchema: mutable.HashMap[String, StructFieldEntity] = null

    // 6000 is presumably a timeout for the schema lookup — TODO confirm units.
    tableSchema = connectEntity.getMppTableSchema(6000)
    if (tableSchema.isEmpty) {
      logError(s"can't found table struct,schema:[${connectEntity.getSchemaName()}], table:[${connectEntity.getTableName()}]")
      throw new Exception("未获取到表的元数据信息!")
    }
    // Default-value handling (call below is intentionally disabled)
    //    targetData = setDefaultValue(mppFileds, mppColsInfo, targetData)
    // Cast each source column to the Spark type matching the target column's SQL type.
    // NOTE(review): tableSchema(fieldName) throws NoSuchElementException when a source
    // column is absent from the target table — confirm that is the intended failure mode.
    targetData.schema.fields.map(field => {
      val fieldName = field.name
      val targetField = tableSchema(fieldName)
      val targetType = targetField.getFieldType
      var dfType: DataType = null
      targetType.toLowerCase match {
        case MppSqlTypes.CHAR | MppSqlTypes.VARCHAR | MppSqlTypes.LONGVARCHAR | MppSqlTypes.TEXT => {
          dfType = StringType
        }
        case MppSqlTypes.INT | MppSqlTypes.INT2 | MppSqlTypes.INT4 | MppSqlTypes.INT8 | MppSqlTypes.TINYINT | MppSqlTypes.SMALLINT | MppSqlTypes.INTEGER => {
          dfType = LongType
        }
        case MppSqlTypes.BIGINT => {
          // ORC does not support bigint, so use LongType
          dfType = LongType
        }
        case MppSqlTypes.FLOAT => {
          dfType = FloatType
        }
        case MppSqlTypes.DOUBLE => {
          dfType = DoubleType
        }
        case MppSqlTypes.DECIMAL | MppSqlTypes.NUMERIC => {
          // All decimal/numeric become Double: the MPP external-table data format does not support decimal
          dfType = DoubleType
        }
        case MppSqlTypes.DATE | MppSqlTypes.DATETIME => {
          dfType = DateType
        }
        case MppSqlTypes.TIMESTAMP | MppSqlTypes.TIMESTAMPTZ => {
          dfType = TimestampType
        }
        case MppSqlTypes.TIME_WITH_TIMEZONE | MppSqlTypes.TIME_WITHOUT_TIMEZONE => {
          dfType = StringType
        }
        case MppSqlTypes.SERIAL => {
          dfType = IntegerType
        }
        case MppSqlTypes.BIGSERIAL => {
          dfType = LongType
        }
        case _ => {
          dfType = NullType
        }
      }
      targetData = targetData.withColumn(fieldName, targetData.col(fieldName).cast(dfType))
    }
    )
    val colArr = new ArrayBuffer[Column]()
    for (elem <- mppFileds) {
      // Field is linked to an upstream source field
      if (!elem.from_field.equals("")) {
        colArr += targetData.col(elem.field)
      } else {
        // Unlinked field: keep it unless it is an auto-increment (serial) column
        if (!elem.data_type.equalsIgnoreCase("SERIAL4") &&
          !elem.data_type.equalsIgnoreCase("SERIAL8")) {
          colArr += targetData.col(elem.field)
        }
      }
    }
    // Keep only the required columns
    targetData = targetData.select(colArr: _*)
    targetData.printSchema()
    (targetData, tableSchema)
  }

  /**
   * Rebuilds a DataFrame for a HAWQ sink: casts every source column to the Spark type
   * that matches the target table's column type, then keeps only the columns required
   * by the front-end field mapping (auto-increment serial columns are dropped when
   * they have no upstream link).
   *
   * @param tblList       split-table list (head is used for the schema lookup when splitting is on)
   * @param tableSplit    split-table state; non-null means table splitting is enabled
   * @param tblName       table name (not referenced in the body; kept for interface parity)
   * @param mppColsInfo   column metadata (not referenced in the body; kept for interface parity)
   * @param mppFileds     front-end field mapping (sink fields and their upstream links)
   * @param data          source DataFrame
   * @param connectEntity JDBC connection wrapper used to read the target table schema
   * @return (rebuilt DataFrame, target table schema keyed by column name)
   */
  def buildNewDataframeHAWQ(tblList: List[String],
                            tableSplit: DmTableSplit,
                            tblName: String,
                            mppColsInfo: util.List[DmTableColumn],
                            mppFileds: List[ExtractFieldInfo],
                            data: DataFrame,
                            connectEntity: JdbcConnectEntity) = {
    var targetData: DataFrame = data
    var tableSchema: mutable.HashMap[String, StructFieldEntity] = null
    if (tableSplit != null) { // Splitting enabled: use the first split table to fetch the schema
      tableSchema = connectEntity.getMppTrueTableSchema(tblList.head, 6000)
    } else { // Splitting disabled: use the real table name to fetch the schema
      tableSchema = connectEntity.getMppTableSchema(6000)
    }
    if (tableSchema.isEmpty) {
      logError(s"can't found mpp table struct,schema:[${connectEntity.getSchemaName()}], table:[${connectEntity.getTableName()}]")
      throw new Exception("未获取到表的元数据信息!")
    }
    // Default-value handling (call below is intentionally disabled)
    //    targetData = setDefaultValue(mppFileds, mppColsInfo, targetData)
    // Cast each source column to the Spark type matching the target column's SQL type.
    // NOTE(review): tableSchema(fieldName) throws NoSuchElementException when a source
    // column is absent from the target table — confirm that is the intended failure mode.
    targetData.schema.fields.map(field => {
      val fieldName = field.name
      val targetField = tableSchema(fieldName)
      val targetType = targetField.getFieldType
      var dfType: DataType = null
      targetType.toLowerCase match {
        case MppSqlTypes.CHAR | MppSqlTypes.VARCHAR | MppSqlTypes.LONGVARCHAR | MppSqlTypes.TEXT => {
          dfType = StringType
        }
        case MppSqlTypes.INT | MppSqlTypes.INT2 | MppSqlTypes.INT4 | MppSqlTypes.INT8 | MppSqlTypes.TINYINT | MppSqlTypes.SMALLINT | MppSqlTypes.INTEGER => {
          dfType = LongType
        }
        case MppSqlTypes.BIGINT => {
          // ORC does not support bigint, so use LongType
          dfType = LongType
        }
        case MppSqlTypes.FLOAT => {
          dfType = FloatType
        }
        case MppSqlTypes.DOUBLE => {
          dfType = DoubleType
        }
        case MppSqlTypes.DECIMAL | MppSqlTypes.NUMERIC => {
          // All decimal/numeric become Double: the MPP external-table data format does not support decimal
          dfType = DoubleType
        }
        case MppSqlTypes.DATE | MppSqlTypes.DATETIME => {
          dfType = DateType
        }
        case MppSqlTypes.TIMESTAMP | MppSqlTypes.TIMESTAMPTZ => {
          dfType = TimestampType
        }
        case MppSqlTypes.TIME_WITH_TIMEZONE | MppSqlTypes.TIME_WITHOUT_TIMEZONE => {
          dfType = StringType
        }
        case MppSqlTypes.SERIAL => {
          dfType = IntegerType
        }
        case MppSqlTypes.BIGSERIAL => {
          dfType = LongType
        }
        case _ => {
          dfType = NullType
        }
      }
      targetData = targetData.withColumn(fieldName, targetData.col(fieldName).cast(dfType))
    }
    )
    val colArr = new ArrayBuffer[Column]()
    for (elem <- mppFileds) {
      // Field is linked to an upstream source field
      if (!elem.from_field.equals("")) {
        colArr += targetData.col(elem.field)
      } else {
        // Unlinked field: keep it unless it is an auto-increment (serial) column
        if (!elem.data_type.equalsIgnoreCase("SERIAL4") &&
          !elem.data_type.equalsIgnoreCase("SERIAL8")) {
          colArr += targetData.col(elem.field)
        }
      }
    }
    // Keep only the required columns
    targetData = targetData.select(colArr: _*)
    targetData.printSchema()
    (targetData, tableSchema)
  }


  /**
   * Generates the HAWQ external-table column definition list from the configured
   * extraction fields, e.g. `"a" bigint,"b" varchar,"c" double precision`.
   *
   * @param extract_fields extraction field list
   * @return comma-separated column definitions (each column name double-quoted)
   */
  def getHAWQExtTableFieldDefined(extract_fields: List[ExtractFieldInfo]): String = {
    extract_fields
      .map { ef =>
        // Map the front-end column type to the HAWQ DDL type; default to varchar.
        val ddlType = ef.data_type match {
          case ColumnType.STRING => VARCHAR
          case ColumnType.NUMBER => BIGINT
          case ColumnType.DECIMAL => DOUBLE
          case ColumnType.TIME => VARCHAR
          case ColumnType.DATETIME => TIMESTAMP
          case ColumnType.DATE => DATE
          case "SERIAL4" => BIGINT
          case "SERIAL8" => BIGINT
          case _ => VARCHAR
        }
        s"""\"${ef.field}\" ${ddlType}"""
      }
      .mkString(",")
  }


  /**
   * Rebuilds a DataFrame for a HashData sink: casts every source column to the Spark
   * type that matches the target table's column type, then keeps only the columns
   * required by the front-end field mapping (auto-increment serial columns are dropped
   * when they have no upstream link).
   *
   * @param tblList        split-table list (head is used for the schema lookup when splitting is on)
   * @param tableSplit     split-table state; non-null means table splitting is enabled
   * @param tblName        table name (not referenced in the body; kept for interface parity)
   * @param mppColsInfo    column metadata (not referenced in the body; kept for interface parity)
   * @param mppFileds      front-end field mapping (sink fields and their upstream links)
   * @param data           source DataFrame
   * @param connectEntity  JDBC connection wrapper used to read the target table schema
   * @param decimal_switch when true, decimal/numeric columns are carried as strings to
   *                       avoid binary-float precision loss; otherwise they become Double
   * @return (rebuilt DataFrame, target table schema keyed by column name)
   */
  def buildNewDataframeHashData(tblList: List[String],
                                tableSplit: DmTableSplit,
                                tblName: String,
                                mppColsInfo: util.List[DmTableColumn],
                                mppFileds: List[ExtractFieldInfo],
                                data: DataFrame,
                                connectEntity: JdbcConnectEntity,
                                decimal_switch: Boolean = false) = {
    log.info(s"buildNewDataframeHashData decimal_switch: $decimal_switch")
    var targetData: DataFrame = data
    var tableSchema: mutable.HashMap[String, StructFieldEntity] = null
    if (tableSplit != null) { // Splitting enabled: use the first split table to fetch the schema
      tableSchema = connectEntity.getMppTrueTableSchema(tblList.head, 6000)
    } else { // Splitting disabled: use the real table name to fetch the schema
      tableSchema = connectEntity.getMppTableSchema(6000)
    }
    if (tableSchema.isEmpty) {
      logError(s"can't found mpp table struct,schema:[${connectEntity.getSchemaName()}], table:[${connectEntity.getTableName()}]")
      throw new Exception("未获取到MPP表的元数据信息!")
    }
    // Default-value handling (call below is intentionally disabled)
    //    targetData = setDefaultValue(mppFileds, mppColsInfo, targetData)
    // Build a cast expression per source column. NOTE(review): the schema lookup is
    // lower-cased here (unlike the PGSql/HAWQ variants) — presumably HashData returns
    // lower-case column names; a miss throws NoSuchElementException.
    val cols = targetData.schema.fields.map(field => {
      val fieldName = field.name
      val targetField = tableSchema(fieldName.toLowerCase())
      val targetType = targetField.getFieldType
      val fieldSize = targetField.getFieldSize
      log.info("目标字段: " + fieldName + " 目标字段类型: " + targetType)
      var dfType: DataType = null
      targetType.toLowerCase match {
        case MppSqlTypes.CHAR | MppSqlTypes.VARCHAR | MppSqlTypes.LONGVARCHAR | MppSqlTypes.TEXT => {
          dfType = StringType
        }
        case MppSqlTypes.INT | MppSqlTypes.INT2 | MppSqlTypes.INT4 | MppSqlTypes.TINYINT | MppSqlTypes.SMALLINT | MppSqlTypes.INTEGER => {
          dfType = IntegerType
        }
        case MppSqlTypes.BIGINT | MppSqlTypes.INT8 => {
          dfType = LongType
        }
        case MppSqlTypes.FLOAT => {
          dfType = FloatType
        }
        case MppSqlTypes.DOUBLE => {
          dfType = DoubleType
        }
        case MppSqlTypes.DECIMAL | MppSqlTypes.NUMERIC => {
          // The MPP external-table data format does not support decimal
          if (decimal_switch) {
            dfType = StringType
          } else {
            // decimal disabled: fall back to Double (may lose precision)
            dfType = DoubleType
          }
        }
        case MppSqlTypes.DATE | MppSqlTypes.DATETIME => {
          dfType = DateType
        }
        case MppSqlTypes.TIMESTAMP | MppSqlTypes.TIMESTAMPTZ => {
          dfType = TimestampType
        }
        case MppSqlTypes.TIME_WITH_TIMEZONE | MppSqlTypes.TIME_WITHOUT_TIMEZONE => {
          dfType = StringType
        }
        case MppSqlTypes.SERIAL => {
          dfType = IntegerType
        }
        case MppSqlTypes.BIGSERIAL => {
          dfType = LongType
        }
        case _ => {
          // Unknown target type: carry as string (unlike PGSql/HAWQ which use NullType)
          dfType = StringType
        }
      }
      log.info("目标字段: " + fieldName + " 目标字段长度: " + fieldSize + " 目标字段类型: " + targetType + " df字段类型: " + dfType)
      //      targetData = targetData.withColumn(fieldName, targetData.col(fieldName).cast(dfType))
      targetData.col(fieldName).cast(dfType)

    }
    )

    targetData = targetData.select(cols: _*)

    val colArr = new ArrayBuffer[Column]()
    for (elem <- mppFileds) {
      // Field is linked to an upstream source field
      if (!elem.from_field.equals("")) {
        colArr += targetData.col(elem.field)
      } else {
        // Unlinked field: keep it unless it is an auto-increment (serial) column
        if (!elem.data_type.equalsIgnoreCase("SERIAL4") &&
          !elem.data_type.equalsIgnoreCase("SERIAL8")) {
          colArr += targetData.col(elem.field)
        }
      }
    }
    // Keep only the required columns
    targetData = targetData.select(colArr: _*)
    (targetData, tableSchema)
  }

  /**
   * Rebuilds a DataFrame for a Greenplum sink: casts every source column to the Spark
   * type that matches the target table's column type, then keeps only the columns
   * required by the front-end field mapping (auto-increment serial columns are dropped
   * when they have no upstream link).
   *
   * @param tblList        split-table list (head is used for the schema lookup when splitting is on)
   * @param tableSplit     split-table state; non-null means table splitting is enabled
   * @param tblName        table name (not referenced in the body; kept for interface parity)
   * @param mppColsInfo    column metadata; used to look up decimal precision/scale per column
   * @param mppFileds      front-end field mapping (sink fields and their upstream links)
   * @param data           source DataFrame
   * @param connectEntity  JDBC connection wrapper used to read the target table schema
   * @param decimal_switch when true, wide decimals (precision > 15) keep a DecimalType
   *                       instead of Double, avoiding precision loss
   * @return (rebuilt DataFrame, target table schema keyed by column name)
   */
  def buildNewDataframeGreenplum(tblList: List[String],
                                 tableSplit: DmTableSplit,
                                 tblName: String,
                                 mppColsInfo: util.List[DmTableColumn],
                                 mppFileds: List[ExtractFieldInfo],
                                 data: DataFrame,
                                 connectEntity: JdbcConnectEntity, decimal_switch: Boolean = false) = {
    log.info(s"buildNewDataframeGreenplum decimal_switch: $decimal_switch")
    var targetData: DataFrame = data
    var tableSchema: mutable.HashMap[String, StructFieldEntity] = null
    if (tableSplit != null) { // Splitting enabled: use the first split table to fetch the schema
      logInfo("tblList.head" + tblList.head)
      tableSchema = connectEntity.getMppTrueTableSchema(tblList.head, 6000)
    } else { // Splitting disabled: use the real table name to fetch the schema
      tableSchema = connectEntity.getMppTableSchema(6000)
    }
    if (tableSchema.isEmpty) {
      logError(s"can't found mpp table struct,schema:[${connectEntity.getSchemaName()}], table:[${connectEntity.getTableName()}]")
      throw new Exception("未获取到MPP表的元数据信息!")
    }
    // Default-value handling (call below is intentionally disabled)
    //    targetData = setDefaultValue(mppFileds, mppColsInfo, targetData)
    // Column name -> length string (e.g. "20,4" for decimals) from data-management metadata.
    val colmap: Map[String, String] = mppColsInfo.asScala.map(col => {
      (col.getColName, col.getLength)
    }).toMap
    // Build a cast expression per source column; a schema miss throws NoSuchElementException.
    val cols = targetData.schema.fields.map(field => {
      val fieldName = field.name
      val targetField = tableSchema(fieldName)
      val targetType = targetField.getFieldType
      val fieldSize = targetField.getFieldSize
      log.info("目标字段: " + fieldName + " 目标字段类型: " + targetType)
      var dfType: DataType = null
      targetType.toLowerCase match {
        case MppSqlTypes.CHAR | MppSqlTypes.VARCHAR | MppSqlTypes.LONGVARCHAR | MppSqlTypes.TEXT => {
          dfType = StringType
        }
        case MppSqlTypes.INT | MppSqlTypes.INT2 | MppSqlTypes.INT4 | MppSqlTypes.TINYINT | MppSqlTypes.SMALLINT | MppSqlTypes.INTEGER => {
          dfType = IntegerType
        }
        case MppSqlTypes.BIGINT | MppSqlTypes.INT8 => {
          dfType = LongType
        }
        case MppSqlTypes.FLOAT => {
          dfType = FloatType
        }
        case MppSqlTypes.DOUBLE => {
          dfType = DoubleType
        }
        case MppSqlTypes.DECIMAL | MppSqlTypes.NUMERIC => {
          // The MPP external-table data format does not support decimal
          if (decimal_switch) {
            // Keep DecimalType only when a "precision,scale" pair is present and the
            // precision exceeds what Double represents exactly (15 digits).
            val fieldLength: String = colmap(fieldName)
            val sizePs: Array[String] = fieldLength.split(",")
            if (sizePs.length == 2) {
              if (sizePs(0).trim.toInt > 15) {
                val p = sizePs(0).trim.toInt
                val s = sizePs(1).trim.toInt
                dfType = DecimalType(p, s)
              } else {
                dfType = DoubleType
              }
            } else {
              dfType = DoubleType
            }
          } else {
            // decimal disabled: fall back to Double (may lose precision)
            dfType = DoubleType
          }
        }
        case MppSqlTypes.DATE | MppSqlTypes.DATETIME => {
          dfType = DateType
        }
        case MppSqlTypes.TIMESTAMP | MppSqlTypes.TIMESTAMPTZ => {
          dfType = TimestampType
        }
        case MppSqlTypes.TIME_WITH_TIMEZONE | MppSqlTypes.TIME_WITHOUT_TIMEZONE => {
          dfType = StringType
        }
        case MppSqlTypes.SERIAL => {
          dfType = IntegerType
        }
        case MppSqlTypes.BIGSERIAL => {
          dfType = LongType
        }
        case _ => {
          // Unknown target type: carry as string
          dfType = StringType
        }
      }
      log.info("目标字段: " + fieldName + " 目标字段长度: " + fieldSize + " 目标字段类型: " + targetType + " df字段类型: " + dfType)
      //      targetData = targetData.withColumn(fieldName, targetData.col(fieldName).cast(dfType))
      targetData.col(fieldName).cast(dfType)

    }
    )

    targetData = targetData.select(cols: _*)

    val colArr = new ArrayBuffer[Column]()
    for (elem <- mppFileds) {
      // Field is linked to an upstream source field
      if (!elem.from_field.equals("")) {
        colArr += targetData.col(elem.field)
      } else {
        // Unlinked field: keep it unless it is an auto-increment (serial) column
        if (!elem.data_type.equalsIgnoreCase("SERIAL4") &&
          !elem.data_type.equalsIgnoreCase("SERIAL8")) {
          colArr += targetData.col(elem.field)
        }
      }
    }
    // Keep only the required columns
    targetData = targetData.select(colArr: _*)
    (targetData, tableSchema)
  }

  /**
   * Builds a [field name -> NOT_NULL flag] map from the column metadata and logs the
   * sink schema / DataFrame columns, then returns null.
   *
   * NOTE(review): this looks like an unfinished stub — it only logs and returns null,
   * unlike setDefaultValue which performs the actual mapping. Confirm whether callers
   * rely on the null return or whether this should be completed/removed.
   *
   * @param sinkSchema  front-end field mapping
   * @param columnEntiy column metadata carrying NOT_NULL parameters
   * @param data        source DataFrame (only its column names are logged)
   * @return always null
   */
  def setJDBCDefaultValue(sinkSchema: List[ExtractFieldInfo], columnEntiy: util.List[DmTableColumn], data: DataFrame): DataFrame = {
    // Build [column name, not_null ("true"/"false", or "" when absent)] pairs.
    val fieldAndNotNull: Map[String, String] = columnEntiy.asScala.map(colEntiy => {
      val colName: String = colEntiy.getColName
      var not_null: String = ""
      colEntiy.getParams.asScala.foreach(map => {
        map.get("pKey") match {
          case "NOT_NULL" => not_null = map.get("pValue")
          case _ =>
        }
      })
      (colName, not_null)
    }).toMap
    log.info("fieldAndNotNull ==> " + fieldAndNotNull.mkString(","))
    log.info("sinkSchema ==> " + sinkSchema.map(_.field).mkString(","))
    log.info("dataFrame ==> " + data.columns.mkString(","))

    null

  }

  /**
   * Transforms the DataFrame for a JDBC sink:
   *
   * 1. keeps only the columns that are linked to an upstream field, renamed to their
   *    sink names;
   * 2. casts every kept column to StringType (TIME values are additionally trimmed,
   *    otherwise writing them to MySQL fails) and materializes unlinked, non-serial
   *    columns as null strings;
   * 3. fills configured default values and validates NOT NULL constraints.
   *
   * The returned DataFrame therefore has every column typed as StringType.
   *
   * @param sinkSchema  front-end field mapping (sink fields, links, defaults)
   * @param columnEntiy column metadata carrying NOT_NULL parameters
   * @param data        source DataFrame
   * @return transformed DataFrame, all columns StringType
   */
  def setDefaultValue(sinkSchema: List[ExtractFieldInfo], columnEntiy: util.List[DmTableColumn], data: DataFrame): DataFrame = {
    // Build [column name, not_null ("true"/"false", or "" when absent)] pairs.
    val fieldAndNotNull: Map[String, String] = columnEntiy.asScala.map(colEntiy => {
      val colName: String = colEntiy.getColName
      var not_null: String = ""
      colEntiy.getParams.asScala.foreach(map => {
        map.get("pKey") match {
          case "NOT_NULL" => not_null = map.get("pValue")
          case _ =>
        }
      })
      (colName, not_null)
    }).toMap
    log.info("fieldAndNotNull ==> " + fieldAndNotNull.mkString(","))
    log.info("sinkSchema ==> " + sinkSchema.map(_.field).mkString(","))
    log.info("dataFrame ==> " + data.columns.mkString(","))
    // Copy only columns that have an upstream link, renamed to the sink field name.
    val colArr = new ArrayBuffer[Column]()
    for (ef <- sinkSchema) {
      if (!ef.from_field.trim.equals("")) {
        val to_field: String = ef.field
        val from_field: String = ef.from_field
        colArr += data.col(from_field) as (to_field)
      }
    }

    // No links at all is only acceptable when at least one default value is configured.
    if (colArr.isEmpty) {
      var isError = true
      sinkSchema.foreach(schema => {
        if (!"".equals(schema.field_props.default_value)) {
          isError = false
        }
      })
      if (isError) {
        throw new Exception("作业配置异常,目标表没有连接上游字段,且默认值都为空!")
      }
    }

    //    log.info("==> 111111")
    var value: DataFrame = data.select(colArr: _*)
    // Fill defaults and convert types on the target DF; everything is cast to String first.

    // ############# fix withColumn begin ##################

    // One cast expression per sink field (null for serial fields, filtered out below).
    var cols = sinkSchema.map(ef => {
      var column: Column = null
      val to_field: String = ef.field
      val data_type = ef.data_type
      val from_field: String = ef.from_field
      val default_value: String = ef.field_props.default_value

      if (!from_field.trim.equals("")) { // Linked field
        // Trim TIME values, otherwise writing them to MySQL fails.
        if (data_type.equalsIgnoreCase("TIME")) {
          column = trim(value.col(to_field).cast(StringType)) as to_field
        } else {
          column = value.col(to_field).cast(StringType)
        }

      } else { // Unlinked field
        if (!data_type.equalsIgnoreCase("SERIAL4") && !data_type.equalsIgnoreCase("SERIAL8")) {
          //          value = value.withColumn(to_field, lit(null).cast(StringType))
          column = lit(null).cast(StringType).as(to_field)
        } /*else{
          // compatibility for MPP auto-increment primary keys
          column = lit(null).cast(StringType).as(to_field)
        }*/
      }
      column
    })

    //    log.info("==> 222222")
    // MPP auto-increment primary-key fields produced a null Column above — drop them.
    cols = cols.filter(_ != null)
    value = value.select(cols: _*)

    // Fill defaults; an unlinked NOT NULL field with no default (and no_check off) is fatal.
    // NOTE(review): fieldAndNotNull(to_field) throws NoSuchElementException if a sink
    // field is missing from the column metadata — confirm that is intended.
    for (ef <- sinkSchema) {
      val to_field: String = ef.field
      val data_type = ef.data_type
      val from_field: String = ef.from_field
      val no_check: Boolean = ef.no_check.get
      val default_value: String = ef.field_props.default_value

      if (!default_value.equals("")) {
        value = value.na.fill(default_value, Array(to_field))

      } else if (!data_type.equalsIgnoreCase("SERIAL4")
        && !data_type.equalsIgnoreCase("SERIAL8")
        && default_value.equals("")
        && fieldAndNotNull(to_field).equalsIgnoreCase("true")
        && from_field.trim.equals("")
        && !no_check) {

        log.error(s"目标字段：${to_field}不能为null!")
        throw new Exception(s"目标字段：${to_field}不能为null!")
      }
    }
    //    log.info("==> 333333")
    value




    // ############# fix withColumn  end ##################

    /* for (ef <- sinkSchema) {
       val to_field: String = ef.field
       val data_type = ef.data_type
       val field_length = ef.length
       val from_field: String = ef.from_field
       val default_value: String = ef.field_props.default_value
       if (!from_field.trim.equals("")) {
         // Trim TIME values, otherwise writing them to MySQL fails
         if (data_type.equalsIgnoreCase("TIME")) {
           value = value.withColumn(to_field, trim(value.col(to_field).cast(StringType)))

         } else {
           value = value.withColumn(to_field, value.col(to_field).cast(StringType))
         }
       } else {
         if (!data_type.equalsIgnoreCase("SERIAL4") &&
           !data_type.equalsIgnoreCase("SERIAL8")) {
           value = value.withColumn(to_field, lit(null).cast(StringType))
         }
       }


       // Fill default values
       if (!default_value.equals("")) {
         value = value.na.fill(default_value, Array(to_field))
       } else if (!data_type.equalsIgnoreCase("SERIAL4") &&
         !data_type.equalsIgnoreCase("SERIAL8") &&
         default_value.equals("") &&
         fieldAndNotNull(to_field).equalsIgnoreCase("true") &&
         from_field.trim.equals("")) {
         log.error(s"目标字段：${to_field}不能为null!")
         throw new Exception(s"目标字段：${to_field}不能为null!")
       }
     }
     value*/
  }

  /**
   * Applies the field mapping for an HDFS sink: keeps only the columns linked to an
   * upstream field (renamed to their sink names) and fills configured default values.
   * A field with neither an upstream link nor a default value is a configuration error.
   *
   * @param sinkSchema front-end field mapping (sink fields, links, defaults)
   * @param data       source DataFrame
   * @return DataFrame with renamed columns and defaults applied
   */
  def setHdfsDefaultValue(sinkSchema: List[ExtractFieldInfo], data: DataFrame): DataFrame = {
    // Keep only columns with an upstream link, renamed to the sink field name.
    val linkedCols: List[Column] =
      for (ef <- sinkSchema if !ef.from_field.trim.equals(""))
        yield data.col(ef.from_field) as ef.field

    // No links at all is only acceptable when at least one default value is configured.
    if (linkedCols.isEmpty && sinkSchema.forall(s => "".equals(s.field_props.default_value))) {
      throw new Exception("作业配置异常,目标表没有连接上游字段,且默认值都为空!")
    }

    // Fold over the schema, filling each configured default; a field that is neither
    // linked nor defaulted aborts the job.
    sinkSchema.foldLeft(data.select(linkedCols: _*)) { (df, ef) =>
      val defaultValue = ef.field_props.default_value
      if (!defaultValue.equals("")) {
        df.na.fill(defaultValue, Array(ef.field))
      } else if (ef.from_field.trim.equals("")) {
        log.error(s"目标字段：${ef.field}不能为null!")
        throw new Exception(s"目标字段：${ef.field}不能为null!")
      } else {
        df
      }
    }
  }

  /** Casts every sink field of the DataFrame to its Spark type; see getDataType(). */
  def convertDataType(sinkSchema: List[ExtractFieldInfo], data: DataFrame): DataFrame = {
    // One cast expression per sink field, applied in a single select.
    val castColumns: immutable.Seq[Column] =
      for (ef <- sinkSchema) yield data.col(ef.field).cast(getDataType(ef.data_type))
    data.select(castColumns: _*)
  }


  /** Casts every sink field to its Spark type for Hive, honoring the field length for decimals; see getDataType(dataType, length). */
  def convertDataType4Hive(sinkSchema: List[ExtractFieldInfo], data: DataFrame): DataFrame = {
    // One cast expression per sink field; length drives decimal precision/scale.
    val castColumns: immutable.Seq[Column] =
      for (ef <- sinkSchema) yield data.col(ef.field).cast(getDataType(ef.data_type, ef.length.get))
    data.select(castColumns: _*)
  }


  /** Maps a front-end column type name to the Spark SQL type used for row values. */
  def getDataType(dataType: String): DataType = dataType match {
    case ColumnType.STRING => DataTypes.StringType
    case ColumnType.NUMBER => DataTypes.LongType
    case ColumnType.DATE => DataTypes.DateType
    case ColumnType.DECIMAL => DataTypes.DoubleType
    case ColumnType.TIME => DataTypes.StringType
    case ColumnType.DATETIME => DataTypes.TimestampType
    case _ => DataTypes.NullType
  }

  /**
   * Maps a front-end column type name to the Spark SQL type used for row values,
   * using the data-management length to pick a high-precision decimal when needed.
   *
   * @param dataType front-end column type name
   * @param length   column length spec; "precision,scale" for decimals
   * @return Spark SQL DataType
   */
  def getDataType(dataType: String, length: String): DataType = {
    val value: DataType = dataType match {
      case ColumnType.STRING => DataTypes.StringType
      case ColumnType.NUMBER => DataTypes.LongType
      case ColumnType.DATE => DataTypes.DateType
      case ColumnType.DECIMAL => { // Use the configured length to pick a higher-precision decimal
        val parts = length.split(",")
        val p = parts(0).trim.toInt
        // Keep DecimalType only when precision exceeds Double's 15 exact digits
        // (fixes precision loss when syncing e.g. Oracle NUMBER to Hive).
        // BUG FIX: the original read parts(1) whenever p > 15 without checking that a
        // scale component exists, so a length like "20" crashed with
        // ArrayIndexOutOfBoundsException; fall back to Double in that case (matching
        // the guard already used in buildNewDataframeGreenplum).
        if (parts.length == 2 && p > 15) {
          val s = parts(1).trim.toInt
          DataTypes.createDecimalType(p, s)
        } else {
          DataTypes.DoubleType
        }
      }
      case ColumnType.TIME => DataTypes.StringType
      case ColumnType.DATETIME => DataTypes.TimestampType
      case _ => DataTypes.NullType
    }
    log.info("datatype:{}", value)
    value
  }

  /**
   * Casts every sink field of the DataFrame to its Phoenix-compatible Spark type;
   * see getDataType4Phoenix().
   *
   * @param sinkSchema    front-end field mapping
   * @param data          source DataFrame
   * @param decimalSwitch when true, wide decimals keep full precision instead of Double
   * @return DataFrame with all sink fields cast
   */
  def convertDataType4Phoenix(sinkSchema: List[ExtractFieldInfo], data: DataFrame, decimalSwitch: Boolean): DataFrame = {
    // Cleanup: the original declared an unused `var column` inside the lambda and an
    // unnecessary `var value` alias of `data`; both removed — behavior is unchanged.
    val castColumns = sinkSchema.map { ef =>
      data.col(ef.field).cast(getDataType4Phoenix(ef.data_type, ef.length.get, decimalSwitch))
    }
    data.select(castColumns: _*)
  }

  /**
   * Maps a front-end column type name to the Spark SQL type used when writing to Phoenix.
   *
   * @param dataType      front-end column type name
   * @param length        column length spec; "precision,scale" for decimals
   * @param decimalSwitch when true, wide decimals keep full precision instead of Double
   * @return Spark SQL DataType
   */
  def getDataType4Phoenix(dataType: String, length: String, decimalSwitch: Boolean): DataType = {
    dataType match {
      case ColumnType.STRING => DataTypes.StringType
      case ColumnType.NUMBER => DataTypes.LongType
      case ColumnType.DATE => DataTypes.DateType
      case ColumnType.DECIMAL => { // Use the configured length to pick a higher-precision decimal
        if (decimalSwitch) {
          val parts = length.split(",")
          val p = parts(0).trim.toInt
          // BUG FIX: the original read parts(1) whenever p > 15 without checking that a
          // scale component exists, so a length like "20" crashed with
          // ArrayIndexOutOfBoundsException; fall back to Double in that case (matching
          // the guard already used in buildNewDataframeGreenplum).
          if (parts.length == 2 && p > 15) {
            val s = parts(1).trim.toInt
            DataTypes.createDecimalType(p, s)
          } else {
            DataTypes.DoubleType
          }
        } else {
          DataTypes.DoubleType
        }
      }
      case ColumnType.TIME => DataTypes.StringType
      case ColumnType.DATETIME => DataTypes.TimestampType
      case _ => DataTypes.NullType
    }
  }


  /**
   * Casts every sink field of the DataFrame to its HBase-compatible Spark type;
   * see getDataType4HBase().
   *
   * @param sinkSchema front-end field mapping
   * @param data       source DataFrame
   * @return DataFrame with all sink fields cast
   */
  def convertDataType4HBase(sinkSchema: List[ExtractFieldInfo], data: DataFrame): DataFrame = {
    // Cleanup: the original declared an unused `var column` inside the lambda and an
    // unnecessary `var value` alias of `data`; both removed — behavior is unchanged.
    val castColumns = sinkSchema.map(ef => data.col(ef.field).cast(getDataType4HBase(ef.data_type)))
    data.select(castColumns: _*)
  }

  /** Maps a front-end column type name to the Spark SQL type used for HBase row values (decimals are carried as strings). */
  def getDataType4HBase(dataType: String): DataType = dataType match {
    case ColumnType.STRING => DataTypes.StringType
    case ColumnType.NUMBER => DataTypes.LongType
    case ColumnType.DATE => DataTypes.DateType
    case ColumnType.DECIMAL => DataTypes.StringType
    case ColumnType.TIME => DataTypes.StringType
    case ColumnType.DATETIME => DataTypes.TimestampType
    case _ => DataTypes.NullType
  }

  /**
   * Reorders the DataFrame's columns to match the order given by the
   * data-management column metadata.
   */
  def sortDataCol(data: DataFrame, dmTableColumnList: util.List[DmTableColumn]): DataFrame = {
    val orderedCols = dmTableColumnList.asScala.map(c => new Column(c.getColName))
    data.select(orderedCols: _*)
  }

  /**
   * Reorders the DataFrame's columns to match the order given by the
   * front-end field list.
   */
  def sortDataCol(data: DataFrame, columnInfoMetaList: List[ExtractFieldInfo]): DataFrame = {
    val orderedCols = columnInfoMetaList.map(info => new Column(info.field))
    data.select(orderedCols: _*)
  }

  /**
   * Generates the MPP external-table column definition list from the configured
   * extraction fields, e.g. `a int,b varchar,c double precision`.
   *
   * Linked fields always get a column definition; unlinked fields get one unless they
   * are auto-increment (SERIAL4/SERIAL8) columns, which the database fills itself.
   *
   * @param extract_fields field list from the job config
   * @param decimal_switch when true, DECIMAL columns are declared varchar so values
   *                       round-trip without binary-float precision loss
   * @return comma-separated column definitions
   */
  def getMppExtTableFieldDefined(extract_fields: List[ExtractFieldInfo], decimal_switch: Boolean = false): String = {
    log.info(s"getMppExtTableFieldDefined decimal_switch: $decimal_switch")

    // Shared front-end-type -> MPP DDL type mapping (both branches below used an
    // identical copy of this match in the original).
    def mppType(fieldType: String): String = fieldType match {
      case ColumnType.STRING => VARCHAR
      case ColumnType.NUMBER => BIGINT
      case ColumnType.DECIMAL => if (decimal_switch) VARCHAR else DOUBLE
      case ColumnType.TIME => TIME
      case ColumnType.DATETIME => TIMESTAMP
      case ColumnType.DATE => DATE
      case "SERIAL4" => INTEGER
      case "SERIAL8" => BIGINT
      case _ => VARCHAR
    }

    val fieldList = new ArrayBuffer[String]()
    for (ef <- extract_fields) {
      if (!ef.from_field.equals("")) {
        // Linked field: always emitted (name unquoted, as before).
        fieldList += s"${ef.field} ${mppType(ef.data_type)}"
      } else if (!ef.data_type.equalsIgnoreCase("SERIAL4")
        && !ef.data_type.equalsIgnoreCase("SERIAL8")) {
        // Unlinked, non-auto-increment field.
        // BUG FIX: the original wrote `fieldList(index) = ...` — an indexed UPDATE on a
        // growable ArrayBuffer. At that point size <= index, so the first unlinked
        // non-serial field always threw IndexOutOfBoundsException. Append instead,
        // matching the linked branch and the getHAWQExtTableFieldDefined intent.
        // NOTE(review): this branch double-quotes the name while the linked branch does
        // not — preserved as-is; confirm which form the target DDL actually requires.
        fieldList += s"""\"${ef.field}\" ${mppType(ef.data_type)}"""
      }
      // Unlinked serial fields are intentionally skipped.
    }
    fieldList.mkString(",")
  }

  /**
   * Extract the field names from a list of schema struct fields.
   *
   * @param colList schema fields
   * @return field names, in the same order
   */
  def getTableFieldNames(colList: List[StructField]): List[String] =
    colList.map(_.name)

  /**
   * Build the per-column conversion (CAST) expressions used when copying data
   * from the MPP external table into the internal table, i.e. the select list of:
   * {{{
   * insert into innerTable (c1, c2, c3) AS
   *   (select CAST(c1 AS timestamp), CAST(c2 AS varchar), c3 from externalTable)
   * }}}
   *
   * @param dfColList columns of the dataframe written to the external table
   * @return one conversion expression per column
   */
  def getFieldConvertExps(tableSplit: DmTableSplit,
                          tblNameAndDF: mutable.HashMap[String, DataFrame],
                          dfColList: List[StructField],
                          connectEntity: JdbcConnectEntity,
                          tblSchema: mutable.HashMap[String, StructFieldEntity]): List[String] = {
    // One CAST expression per dataframe column, driven by the MPP schema.
    dfColList.map { structField =>
      val colName = structField.name
      // Throws NoSuchElementException when the column is missing from the MPP
      // schema — the same failure mode as the original's `.get` on an Option,
      // without the Option.get anti-pattern.
      val mppColType = tblSchema(colName).getFieldType
      mppColType match {
        // serial/bigserial are auto-increment pseudo-types: cast to their
        // underlying integer types.
        case "serial" => s"CAST(${colName} AS integer)"
        case "bigserial" => s"CAST(${colName} AS bigint)"
        case _ => s"CAST(${colName} AS ${mppColType})"
      }
    }
  }

  /**
   * Repartition a dataframe to `nums` partitions, avoiding a shuffle when the
   * partition count is being reduced.
   *
   * @param inputDF dataframe to repartition
   * @param nums    target partition count
   * @param clazz   caller marker, used only in the shuffle log message
   * @return repartitioned dataframe
   */
  def repartionDataframe[T](inputDF: DataFrame, nums: Int, clazz: T = this): DataFrame = {
    val currentPartitions = inputDF.rdd.partitions.size
    if (currentPartitions >= nums) {
      // Shrinking: coalesce needs no shuffle.
      inputDF.coalesce(nums)
    } else {
      // Growing: repartition implies a shuffle.
      logInfo(s"==> Current operations contains shuffle ![${clazz.getClass}]")
      inputDF.repartition(nums)
    }
  }

  /**
   * Roughly estimate the dataframe size in MB (HBase variant: adds the
   * column-name bytes and a 2-byte column-family overhead per row).
   *
   * @return (row count, estimated size in MB)
   */
  def estimateDFSize(df: DataFrame, columnInfoMetaList: List[ExtractFieldInfo]): (Long, Int) = {
    val colNamesBytes = columnInfoMetaList.map(_.field.getBytes.length).sum
    val colFamilyBytes = 2 // '0'
    val rowBytes = df.schema.fields.map(_.dataType.defaultSize).sum
    val count = df.count()
    val dfSizeMb = (count * (rowBytes + colNamesBytes + colFamilyBytes) / 1024 / 1024).toInt
    logInfo(s"=> estimateDFSize: ($count rows, ${dfSizeMb} mb)")
    (count, dfSizeMb)
  }

  /**
   * Roughly estimate the dataframe size in MB from the schema's per-type
   * default sizes.
   *
   * @return (row count, estimated size in MB)
   */
  def estimateDFSize(df: DataFrame): (Long, Int) = {
    val rowBytes = df.schema.fields.foldLeft(0)(_ + _.dataType.defaultSize)
    val count = df.count()
    val dfSizeMb = (count * rowBytes / 1024 / 1024).toInt
    logInfo(s"=> estimateDFSize: ($count rows, ${dfSizeMb} mb)")
    (count, dfSizeMb)
  }

  /**
   * Persist a dataframe at the given storage level (MEMORY_AND_DISK by
   * default) and log which caller cached it.
   *
   * @param inputdf      dataframe to cache
   * @param Clazz        caller marker, used only for logging
   * @param storageLevel Spark storage level
   * @return the persisted dataframe
   */
  def cacheDataframe[T](inputdf: DataFrame, Clazz: T, storageLevel: StorageLevel = MEMORY_AND_DISK): DataFrame = {
    val persisted = inputdf.persist(storageLevel)
    logInfo(s"==> [${Clazz.getClass}] output dataframe had cached!")
    persisted
  }

  /**
   * Fixed-width load: split every line of the input RDD into column values
   * according to each column's declared length.
   *
   * Fixes vs. original: the column widths are parsed to Int once on the
   * driver instead of `toInt` per record inside the map closure; the
   * side-effecting `columns.asScala.map` (result discarded) and the mutable
   * `var arrRDD` are gone.
   *
   * @param cosRDD  raw input lines
   * @param columns column metadata carrying the fixed width of each field
   * @return RDD of per-line column-value arrays
   */
  def Loading(cosRDD: RDD[String], columns: java.util.List[DmTableColumn]): RDD[Array[String]] = {
    import scala.collection.JavaConverters._
    try {
      // Fixed width of every column. For decimal(16,2) the width is the
      // precision (16) plus one for the decimal point.
      val widths: List[Int] = columns.asScala.map { column =>
        val declared: String = column.getLength
        if (declared.contains(",")) declared.split(",")(0).toInt + 1 else declared.toInt
      }.toList
      val colSize: Int = columns.size()
      cosRDD.map { line =>
        val values = new Array[String](colSize)
        var offset = 0
        var i = 0
        widths.foreach { width =>
          values.update(i, line.substring(offset, offset + width))
          offset += width
          i += 1
        }
        values
      }
    } catch {
      case e: Exception =>
        e.printStackTrace()
        println(s"定长加载失败,失败信息: ${e.getMessage}, 失败原因: ${e.getCause}")
        throw new Exception(s"定长加载失败,失败信息: ${e.getMessage}, 失败原因: ${e.getCause}")
    }
  }

  /**
   * Fixed-width load (row variant): same splitting logic as [[Loading]].
   *
   * Fixes vs. original: widths parsed once on the driver instead of per
   * record; no side-effecting `.map`, no mutable `var arrRDD`.
   *
   * @param cosRDD  raw input lines
   * @param columns column metadata carrying the fixed width of each field
   * @return RDD of per-line column-value arrays
   */
  def LoadingRow(cosRDD: RDD[String], columns: java.util.List[DmTableColumn]): RDD[Array[String]] = {
    import scala.collection.JavaConverters._
    try {
      // For decimal(16,2) the width is the precision (16) plus one for the
      // decimal point.
      val widths: Array[Int] = columns.asScala.map { column =>
        val declared: String = column.getLength
        if (declared.contains(",")) declared.split(",")(0).toInt + 1 else declared.toInt
      }.toArray
      cosRDD.map { line =>
        val values = new Array[String](widths.length)
        var offset = 0
        for (i <- widths.indices) {
          values(i) = line.substring(offset, offset + widths(i))
          offset += widths(i)
        }
        values
      }
    } catch {
      case e: Exception =>
        e.printStackTrace()
        println(s"定长加载失败,失败信息: ${e.getMessage}, 失败原因: ${e.getCause}")
        throw new Exception(s"定长加载失败,失败信息: ${e.getMessage}, 失败原因: ${e.getCause}")
    }
  }

  /**
   * Fixed-width load with a per-column width increment.
   *
   * Fixes vs. original: widths parsed once on the driver instead of per
   * record; no side-effecting `.map`, no mutable `var arrRDD`.
   *
   * @param cosRDD    raw input lines
   * @param columns   column metadata carrying the fixed width of each field
   * @param increment extra characters added to every column's width
   * @return RDD of per-line column-value arrays
   */
  def Loading(cosRDD: RDD[String], columns: java.util.List[DmTableColumn], increment: Int): RDD[Array[String]] = {
    import scala.collection.JavaConverters._
    try {
      // Base width per column. For decimal(16,2) the width is the precision
      // (16) plus one for the decimal point.
      val baseWidths: List[Int] = columns.asScala.map { column =>
        val declared: String = column.getLength
        if (declared.contains(",")) declared.split(",")(0).toInt + 1 else declared.toInt
      }.toList
      val colSize: Int = columns.size()
      cosRDD.map { line =>
        val values = new Array[String](colSize)
        var offset = 0
        var i = 0
        baseWidths.foreach { base =>
          val width = base + increment
          values.update(i, line.substring(offset, offset + width))
          offset += width
          i += 1
        }
        values
      }
    } catch {
      case e: Exception =>
        e.printStackTrace()
        println(s"定长加载失败,失败信息: ${e.getMessage}, 失败原因: ${e.getCause}")
        throw new Exception(s"定长加载失败,失败信息: ${e.getMessage}, 失败原因: ${e.getCause}")
    }
  }

  /**
   * Fixed-width unload: cast every column to string, left-pad each value with
   * spaces to its declared width, and concatenate the row into a single
   * "value" column (joined with CSV_DELIMITER when one is provided).
   *
   * Fixes vs. original: the per-field width is parsed once on the driver
   * instead of being re-split/re-parsed for every field of every row, and the
   * side-effecting `columns.asScala.map` used to build the lookup is gone.
   *
   * @throws IllegalArgumentException (at executor time) when a value exceeds
   *                                  its declared width
   */
  def UnLoading(cosDF: DataFrame, columns: java.util.List[DmTableColumn], CSV_DELIMITER: String): DataFrame = {
    import scala.collection.JavaConverters._

    // Cast every column to string first.
    val stringCols: mutable.Buffer[Column] = columns.asScala.map(c => col(c.getColName).cast(StringType))
    var data: DataFrame = cosDF.select(stringCols: _*)

    // field name -> fixed width. For decimal(16,2) the width is the precision
    // (16) plus one for the decimal point.
    val widthByField: Map[String, Int] = columns.asScala.map { column =>
      val declared: String = column.getLength
      val width = if (declared.contains(",")) declared.split(",")(0).toInt + 1 else declared.toInt
      column.getColName -> width
    }.toMap

    val df_cols: Array[String] = data.columns
    try {
      val rdd: RDD[String] = data.rdd.map { row =>
        val padded: Array[String] = df_cols.map { field_name =>
          val length = widthByField(field_name)
          var value: String = row.getAs[String](field_name)
          if (value.length < length) {
            // Left-pad with spaces up to the declared width.
            value = " " * (length - value.length) + value
          } else if (value.length > length) {
            // Value longer than the declared width: fail the job.
            throw new IllegalArgumentException(s"定长卸载失败: 源字段${field_name}长度为${value.length},大于卸载指定长度${length}")
          }
          value
        }
        // With a delimiter configured, join fields; otherwise concatenate.
        if (null != CSV_DELIMITER && !CSV_DELIMITER.isEmpty) padded.mkString(CSV_DELIMITER) else padded.mkString
      }
      val tmp_spark: SparkSession = data.sparkSession
      import tmp_spark.implicits._
      // Single-column DF holding the fixed-width line.
      data = rdd.toDF("value")
    } catch {
      case e: Exception =>
        e.printStackTrace()
        println(s"定长卸载失败,失败信息: ${e.getMessage}, 失败原因: ${e.getCause}")
        throw new Exception(s"定长卸载失败,失败信息: ${e.getMessage}, 失败原因: ${e.getCause}")
    }
    data
  }

  /**
   * hivesink: estimate a reasonable partition count to mitigate the
   * small-files problem.
   *
   * @param kc          khaos context (unused; kept for interface compatibility)
   * @param data        data (unused; kept for interface compatibility)
   * @param columnsInfo column metadata used to approximate the per-row size
   * @param accumValue  total row count (from an accumulator)
   * @return partition count, clamped to [1, 200]
   */
  def estimatePartitions(kc: KhaosContext, data: DataFrame, columnsInfo: util.List[DmTableColumn], accumValue: Long): Int = {
    // Roughly estimate one row's size from the declared column lengths.
    val bytesPerRow: Long = columnsInfo.asScala.map { colEntity =>
      val raw: String = colEntity.getLength
      // ROBUSTNESS FIX: the original called raw.contains(",") before its
      // null/empty check and would NPE on a column with no declared length;
      // the intended default of 50 now actually applies.
      if (raw == null || raw.isEmpty) {
        50
      } else if (raw.contains(",")) {
        raw.split(",")(0).toInt // decimal(10,2) -> use the precision, 10
      } else {
        raw.toInt
      }
    }.sum.toLong

    // Total data volume in bytes.
    val totalData: Long = bytesPerRow * accumValue
    // partitions = bytes / 1024 / 1024 / 200MB-per-task (1-2x an HDFS block).
    var partitions: Long = Math.floorDiv(Math.floorDiv(Math.floorDiv(totalData, 1024), 1024), 200)

    // Clamp to [1, 200]; 200 is Spark SQL's default spark.sql.shuffle.partitions.
    if (partitions > 200) {
      partitions = 200
    } else if (partitions < 1) {
      partitions = 1
    }

    log.info("分区数: " + partitions.toInt)
    partitions.toInt
  }

  /** Mysql / Oracle / SqlServer source: estimate a reasonable number of task
   * segments to mitigate small-file problems.
   *
   * NOTE(review): this collects the WHOLE dataframe to the driver to measure
   * the average row length — same as the original implementation; confirm the
   * input really is a bounded sample.
   *
   * @param count_num           total row count
   * @param dataFrame           sample dataframe used to estimate average row size
   * @param _jdbc_sharding_size data volume handled by a single task
   * @return segmentation count, clamped to [1, 200]
   */
  def estimateTaskSegmentation(count_num: Long, dataFrame: DataFrame, _jdbc_sharding_size: Int): Int = {
    val sampled: Array[Row] = dataFrame.collect()
    log.info("抽样条数：" + sampled.length)

    // Sum the string lengths of all non-null cells.
    var totalLen = 0
    sampled.foreach { row =>
      var idx = 0
      while (idx < row.length) {
        val cell = row(idx)
        if (cell != null) {
          totalLen += cell.toString.length
        }
        idx += 1
      }
    }
    val avglen: Long =
      if (sampled.nonEmpty) Math.ceil(totalLen / sampled.length.toDouble).toLong else 0L

    log.info("总数据条数：" + count_num)
    log.info("每条数据平均大小：" + avglen)
    val totalData: Long = count_num * avglen
    log.info("总数据量(位)：" + totalData)
    // segments = bits / 8 / 1024 / 1024 / per-task size (1-2x an HDFS block).
    var segmentations: Long = Math.ceil(Math.ceil(Math.ceil(totalData / 8d / 1024d) / 1024d) / _jdbc_sharding_size.toDouble).toLong
    if (segmentations.toInt > 200) {
      segmentations = 200
    } else if (segmentations.toInt < 1) {
      segmentations = 1
    }
    log.info("分区数: " + segmentations.toInt)
    segmentations.toInt
  }

  /** influxdb: estimate a reasonable partition count to mitigate the
   * small-files problem.
   *
   * @param kc   khaos context (not used in the computation)
   * @param data data, used for the current partition count
   * @return partition count
   */
  def estimateInfluxdbPartitions(kc: KhaosContext, data: DataFrame, influxColsInfo: util.List[DmTableColumn], count_num: Long, _repartition_per_nums: Int, _max_partition_nums: Int): Int = {
    // Declared length per column; Long/Time columns carry no usable length
    // and contribute null here (defaulted below).
    val declaredLengths: List[String] = influxColsInfo.asScala.map { colEntity =>
      if (colEntity.getColType != "Long" && colEntity.getColType != "Time") {
        var length: String = colEntity.getLength
        log.info("字段长度：" + length)
        // decimal(10,2) -> keep only the precision (10)
        if (length.contains(",")) {
          length = length.split(",")(0)
        }
        length
      } else {
        null
      }
    }.toList

    // Sum per-column sizes: missing length -> 50, capped at 65535.
    var dataNum = 0
    declaredLengths.foreach { elem =>
      if (elem == null || elem == "") {
        dataNum += 50
      } else if (elem.toLong > 65535) {
        dataNum += 65535
      } else {
        dataNum += elem.toInt
      }
    }

    val everyDataNum: Long = dataNum
    log.info("每条数据大小：" + everyDataNum.toInt)
    log.info("总数据条数：" + count_num)
    val totalData: Long = everyDataNum * count_num
    log.info("总数据量b：" + totalData)
    // partitions = bytes / 1024 / 1024 / per-task volume.
    var partitions: Long = Math.ceil(Math.ceil(Math.ceil(totalData / 1024d) / 1024d) / _repartition_per_nums.toDouble).toLong
    log.info("分区数: " + partitions.toInt)

    // Clamp to [1, _max_partition_nums].
    if (partitions > _max_partition_nums) {
      partitions = _max_partition_nums
    } else if (partitions < 1) {
      partitions = 1
    }

    log.info(s"data partitions: ${data.rdd.partitions.length}")
    log.info(s"now partitions: ${partitions.toInt}")
    // Never go below the data's current partition count.
    partitions = Math.max(data.rdd.partitions.length, partitions.toInt)
    log.info("分区数: " + partitions.toInt)
    partitions.toInt
  }

  /**
   * Re-partition estimate based on data volume and current partitioning.
   *
   * @param kc          khaos context (unused; kept for interface compatibility)
   * @param data        data
   * @param mppColsInfo column metadata used to approximate the per-row size
   * @return partition count
   */
  def rePartitions(kc: KhaosContext, data: DataFrame, mppColsInfo: util.List[DmTableColumn]): Int = {
    // Approximate bytes per row from the declared column lengths.
    val bytesPerRow: Long = mppColsInfo.asScala.map { colEntity =>
      val raw: String = colEntity.getLength
      // ROBUSTNESS FIX: the original dereferenced the length before its
      // null/empty check and would NPE on a column with no declared length;
      // the intended default of 50 now actually applies.
      if (raw == null || raw.isEmpty) {
        50
      } else if (raw.contains(",")) {
        raw.split(",")(0).toInt // decimal(10,2) -> use the precision, 10
      } else {
        raw.toInt
      }
    }.sum.toLong

    // NOTE: triggers a Spark job to count the rows.
    val numCount = data.count()
    println("数据总计: " + numCount + " 条")
    // Total bytes, then partitions = bytes / 200 / 1024 / 1024
    // (~200MB handled per task).
    val totalBytes: Long = bytesPerRow * numCount
    var partitions: Long = Math.floorDiv(Math.floorDiv(Math.floorDiv(totalBytes, 200), 1024), 1024)

    // Clamp (original bounds preserved: >=200 caps, <=1 floors).
    if (partitions >= 200) {
      partitions = 200
    } else if (partitions <= 1) {
      partitions = 1
    }
    // (A dead, discarded `data.rdd.partitions.length` statement from the
    // original was removed here.)
    log.info(s"data partitions: ${data.rdd.partitions.length}")
    log.info(s"now partitions: ${partitions.toInt}")
    // Never go below the data's current partition count.
    partitions = Math.max(data.rdd.partitions.length, partitions.toInt)
    log.info("分区数: " + partitions.toInt)
    partitions.toInt
  }

  /** Count dataframe rows via a Spark accumulator: returns the dataframe
   * wrapped with a counting map plus the accumulator (populated once an
   * action runs on the returned dataframe). */
  def calculateDataNum(kc: KhaosContext, data: DataFrame, sinkType: String): (DataFrame, LongAccumulator) = {
    val accumulator: LongAccumulator =
      kc.sparkSession.sparkContext.longAccumulator(s"DataNumber_${sinkType}_${Random.nextInt(1000)}")
    val counted: Dataset[Row] = data.map { row =>
      accumulator.add(1)
      row
    }(RowEncoder(data.schema))
    (counted, accumulator)
  }

  /** Overload of the accumulator-based row counter taking a SparkSession
   * directly instead of the khaos context. */
  def calculateDataNum(sparkSession: SparkSession, data: DataFrame, accName: String): (DataFrame, LongAccumulator) = {
    val acc: LongAccumulator =
      sparkSession.sparkContext.longAccumulator(s"DataNumber_${accName}_${Random.nextInt(1000)}")
    val withCounter: Dataset[Row] = data.map(row => { acc.add(1); row })(RowEncoder(data.schema))
    (withCounter, acc)
  }

  /** Report data status (relational-database sinks).
   *
   * NOTE(review): the entire implementation below is commented out, so this
   * method is currently a no-op — confirm whether status reporting was
   * intentionally disabled before deleting the dead code.
   */
  def reportDataStatusRelation(kc: KhaosContext, dataStatus: RelationDataStatusInfo, dbName: String, tblName: String, clazz: String, metaParamJson: String): Unit = {
    /*  try {
        // 生成上报的KV map
        val indicatorsMap = new mutable.HashMap[String, String]()
        var bizDate: String = kc.conf.getString(SchedulerConstants.BIZ_DATE) + " " + kc.conf.getString(SchedulerConstants.BIZ_TIME)

        val format = new SimpleDateFormat("yyyyMMdd HH:mm:ss")
        val date: util.Date = format.parse(bizDate)
        val format1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss")
        bizDate = format1.format(date)

        indicatorsMap.put(IndicatorsEnum.CURRENT_BTIME, bizDate)
        indicatorsMap.put(IndicatorsEnum.DATA_NUM, dataStatus.getDataNum)
        if (dataStatus != null && dataStatus.isCover != null) {
          indicatorsMap.put(IndicatorsEnum.isCover, dataStatus.isCover.toString)
        }
        //上报数据状态
        val result: ReportDataStatusResult = DataStatusUtils.reportDataStatus(kc, dbName, tblName, indicatorsMap, clazz, metaParamJson)
        //返回结果
        if (result == null) {
          log.warn("数据状态上报失败: result=null")
        } else if (result != null && result.getStatus != 200) {
          log.warn(String.format("数据状态上报失败: status: %s, errMessage: %s, result: %s ", result.getStatus.toString, result.getErrMessage, result.isResult.toString))
        } else if (result != null && result.getStatus == 200) {
          log.info(String.format("数据状态上报成功: status: %s, message: %s, result: %s ", result.getStatus.toString, result.getMessage, result.isResult.toString))
        }
      } catch {
        case e: Exception => {
          e.printStackTrace()
          log.warn(s"数据状态上报失败,失败信息 ${e.getMessage}, 失败原因: ${e.getCause}")
        }
      }*/

  }

  /** Report data status (COS sinks).
   *
   * NOTE(review): the entire implementation below is commented out, so this
   * method is currently a no-op — confirm whether status reporting was
   * intentionally disabled before deleting the dead code.
   */
  def reportDataStatusCos(kc: KhaosContext, fileInfo: CosDataStatusInfo, dbName: String, tblName: String, clazz: String, metaParamJson: String): Unit = {
    /*   try {
         // 生成上报的KV map
         val indicatorsMap = new mutable.HashMap[String, String]()
         var bizDate: String = kc.conf.getString(SchedulerConstants.BIZ_DATE) + " " + kc.conf.getString(SchedulerConstants.BIZ_TIME)

         val format = new SimpleDateFormat("yyyyMMdd HH:mm:ss")
         val date: util.Date = format.parse(bizDate)
         val format1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss")
         bizDate = format1.format(date)
         indicatorsMap.put(IndicatorsEnum.CURRENT_BTIME, bizDate)
         indicatorsMap.put(IndicatorsEnum.FILE_NAME, fileInfo.getFileName)
         indicatorsMap.put(IndicatorsEnum.FILE_LOCATION, fileInfo.getFileLocation)
         indicatorsMap.put(IndicatorsEnum.DATA_SIZE, fileInfo.getFileSize)
         indicatorsMap.put(IndicatorsEnum.isCover, "true")
         //上报数据状态
         val result: ReportDataStatusResult = DataStatusUtils.reportDataStatus(kc, dbName, tblName, indicatorsMap, clazz, metaParamJson)
         //返回结果
         if (result == null) {
           log.warn("数据状态上报失败: result=null")
         } else if (result != null && result.getStatus != 200) {
           log.warn(String.format("数据状态上报失败: status: %s, errMessage: %s, result: %s ", result.getStatus.toString, result.getErrMessage, result.isResult.toString))
         } else if (result != null && result.getStatus == 200) {
           log.info(String.format("数据状态上报成功: status: %s, message: %s, result: %s ", result.getStatus.toString, result.getMessage, result.isResult.toString))
         }
       } catch {
         case e: Exception => {
           e.printStackTrace()
           log.warn(s"数据状态上报失败,失败信息 ${e.getMessage}, 失败原因: ${e.getCause}")
         }
       }*/

  }


  /** Report data status (KS3 sinks).
   *
   * NOTE(review): the entire implementation below is commented out, so this
   * method is currently a no-op — confirm whether status reporting was
   * intentionally disabled before deleting the dead code.
   */
  def reportDataStatusKs3(kc: KhaosContext, fileInfo: Ks3DataStatusInfo, dbName: String, tblName: String, clazz: String, metaParamJson: String): Unit = {
    /*  try {
        // 生成上报的KV map
        val indicatorsMap = new mutable.HashMap[String, String]()

        val bizDate: String = formatDate(kc)
        indicatorsMap.put(IndicatorsEnum.CURRENT_BTIME, bizDate)
        indicatorsMap.put(IndicatorsEnum.FILE_NAME, fileInfo.getFileName)
        indicatorsMap.put(IndicatorsEnum.FILE_LOCATION, fileInfo.getFileLocation)
        indicatorsMap.put(IndicatorsEnum.DATA_SIZE, fileInfo.getFileSize)
        indicatorsMap.put(IndicatorsEnum.isCover, "true")
        //上报数据状态
        val result: ReportDataStatusResult = DataStatusUtils.reportDataStatus(kc, dbName, tblName, indicatorsMap, clazz, metaParamJson)
        //返回结果
        if (result == null) {
          log.warn("数据状态上报失败: result=null")
        } else if (result != null && result.getStatus != 200) {
          log.warn(String.format("数据状态上报失败: status: %s, errMessage: %s, result: %s ", result.getStatus.toString, result.getErrMessage, result.isResult.toString))
        } else if (result != null && result.getStatus == 200) {
          log.info(String.format("数据状态上报成功: status: %s, message: %s, result: %s ", result.getStatus.toString, result.getMessage, result.isResult.toString))
        }
      } catch {
        case e: Exception => {
          e.printStackTrace()
          log.warn(s"数据状态上报失败,失败信息 ${e.getMessage}, 失败原因: ${e.getCause}")
        }
      }*/

  }

  /** Report data status (HDFS sinks).
   *
   * NOTE(review): the entire implementation below is commented out, so this
   * method is currently a no-op — confirm whether status reporting was
   * intentionally disabled before deleting the dead code.
   */
  def reportDataStatusHdfs(kc: KhaosContext, fileInfo: HdfsDataStatusInfo, dbName: String, tblName: String, clazz: String, metaParamJson: String): Unit = {
    /*try {
      log.info("reportDataStatusHdfs:" + metaParamJson)
      // 生成上报的KV map
      val indicatorsMap = new mutable.HashMap[String, String]()
      var bizDate: String = kc.conf.getString(SchedulerConstants.BIZ_DATE) + " " + kc.conf.getString(SchedulerConstants.BIZ_TIME)

      val format = new SimpleDateFormat("yyyyMMdd HH:mm:ss")
      val date: util.Date = format.parse(bizDate)
      val format1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss")
      bizDate = format1.format(date)
      indicatorsMap.put(IndicatorsEnum.CURRENT_BTIME, bizDate)
      indicatorsMap.put(IndicatorsEnum.FILE_NAME, fileInfo.getFileName)
      indicatorsMap.put(IndicatorsEnum.FILE_LOCATION, fileInfo.getFileLocation)
      indicatorsMap.put(IndicatorsEnum.DATA_SIZE, fileInfo.getFileSize)
      indicatorsMap.put(IndicatorsEnum.isCover, "true")
      //上报数据状态
      val result: ReportDataStatusResult = DataStatusUtils.reportDataStatus(kc, dbName, tblName, indicatorsMap, clazz, metaParamJson)
      //返回结果
      if (result == null) {
        log.warn("数据状态上报失败: result=null")
      } else if (result != null && result.getStatus != 200) {
        log.warn(String.format("数据状态上报失败: status: %s, errMessage: %s, result: %s ", result.getStatus.toString, result.getErrMessage, result.isResult.toString))
      } else if (result != null && result.getStatus == 200) {
        log.info(String.format("数据状态上报成功: status: %s, message: %s, result: %s ", result.getStatus.toString, result.getMessage, result.isResult.toString))
      }
    } catch {
      case e: Exception => {
        e.printStackTrace()
        log.warn(s"数据状态上报失败,失败信息 ${e.getMessage}, 失败原因: ${e.getCause}")
      }
    }*/
  }

  /** Convert the scheduler business date/time ("yyyyMMdd" + "HH:mm:ss" from
   * the config) into the "yyyy-MM-dd HH:mm:ss" form used for reporting. */
  def formatDate(kc: KhaosContext): String = {
    val raw = s"${kc.conf.getString(SchedulerConstants.BIZ_DATE)} ${kc.conf.getString(SchedulerConstants.BIZ_TIME)}"
    val parsed: util.Date = new SimpleDateFormat("yyyyMMdd HH:mm:ss").parse(raw)
    new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(parsed)
  }


}
